#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from math import copysign
import warnings
from collections import deque, OrderedDict

import pandas as pd
import numpy as np

from .utils import print_table, format_asset

# Per-trade PnL statistics: maps a display name (table row label) to either
# a callable applied to a group's pnl Series or a pandas aggregation string
# (e.g. 'mean'). Ratio-style stats guard against division by zero by
# returning NaN.
PNL_STATS = OrderedDict(
    [('Total profit', lambda x: x.sum()),
     ('Gross profit', lambda x: x[x > 0].sum()),
     ('Gross loss', lambda x: x[x < 0].sum()),
     ('Profit factor', lambda x: x[x > 0].sum() / x[x < 0].abs().sum()
     if x[x < 0].abs().sum() != 0 else np.nan),
     ('Avg. trade net profit', 'mean'),
     ('Avg. winning trade', lambda x: x[x > 0].mean()),
     ('Avg. losing trade', lambda x: x[x < 0].mean()),
     ('Ratio Avg. Win:Avg. Loss', lambda x: x[x > 0].mean() /
                                            x[x < 0].abs().mean() if x[x < 0].abs().mean() != 0 else np.nan),
     ('Largest winning trade', 'max'),
     ('Largest losing trade', 'min'),
     ])

# Count-based summary statistics over a group's pnl Series: display name ->
# callable or pandas aggregation string.
SUMMARY_STATS = OrderedDict(
    [('Total number of round_trips', 'count'),
     ('Percent profitable', lambda x: len(x[x > 0]) / float(len(x))),
     ('Winning round_trips', lambda x: len(x[x > 0])),
     ('Losing round_trips', lambda x: len(x[x < 0])),
     ('Even round_trips', lambda x: len(x[x == 0])),
     ])

# Return-based statistics over a group's returns Series: display name ->
# callable or pandas aggregation string.
RETURN_STATS = OrderedDict(
    [('Avg returns all round_trips', lambda x: x.mean()),
     ('Avg returns winning', lambda x: x[x > 0].mean()),
     ('Avg returns losing', lambda x: x[x < 0].mean()),
     ('Median returns all round_trips', lambda x: x.median()),
     ('Median returns winning', lambda x: x[x > 0].median()),
     ('Median returns losing', lambda x: x[x < 0].median()),
     ('Largest winning trade', 'max'),
     ('Largest losing trade', 'min'),
     ])

# Duration statistics over a group's duration (Timedelta) Series: display
# name -> callable. All entries are callables (no pandas agg strings).
DURATION_STATS = OrderedDict(
    [('Avg duration', lambda x: x.mean()),
     ('Median duration', lambda x: x.median()),
     ('Longest duration', lambda x: x.max()),
     ('Shortest duration', lambda x: x.min())
     #  FIXME: Instead of x.max() - x.min() this should be
     #  rts.close_dt.max() - rts.open_dt.min() which is not
     #  available here. As it would require a new approach here
     #  that passes in multiple fields we disable these measures
     #  for now.
     #  ('Avg # round_trips per day', lambda x: float(len(x)) /
     #   (x.max() - x.min()).days),
     #  ('Avg # round_trips per month', lambda x: float(len(x)) /
     #   (((x.max() - x.min()).days) / APPROX_BDAYS_PER_MONTH)),
     ])


# def agg_all_long_short(round_trips, col, stats_dict):
#     # Aggregating for all trades
#     print("stats_dict = ", stats_dict)
#     stats_all = (round_trips
#                  .assign(ones=1)
#                  .groupby('ones')[col])
#     stats_all = stats_all.agg(stats_dict)
#     stats_all = stats_all.T.rename(columns={1.0: 'All trades'})
#
#     # Aggregating for long and short trades
#     # Use `rename(columns=...)` instead of `rename_axis`
#     stats_long_short = (round_trips
#                         .groupby('long')[col])
#     stats_long_short = stats_long_short.agg(stats_dict)
#     stats_long_short = stats_long_short.T.rename(columns={False: 'Short trades', True: 'Long trades'})
#
#     # Join the two results
#     return stats_all.join(stats_long_short)

def agg_all_long_short(round_trips, col, stats_dict):
    """Aggregate a round-trips column over all trades and long/short trades.

    Parameters
    ----------
    round_trips : pd.DataFrame
        One row per round-trip trade; must contain `col` and a boolean
        'long' column.
    col : str
        Name of the column in `round_trips` to aggregate.
    stats_dict : OrderedDict
        Maps a display name to either a callable applied to the grouped
        Series or a pandas aggregation string (e.g. 'mean').

    Returns
    -------
    pd.DataFrame
        One row per statistic (indexed by display name, in `stats_dict`
        order) with columns 'All trades', 'Short trades', 'Long trades'.
    """

    def _agg_stats(grouped):
        # Evaluate each statistic individually so every result keeps its
        # display name and the OrderedDict ordering is preserved. Passing
        # the built-in agg strings as a bare list would label rows 'mean'/
        # 'max'/... and reorder them; it also breaks when the list is
        # empty (e.g. DURATION_STATS has no built-in aggregations).
        columns = []
        for name, func in stats_dict.items():
            if callable(func):
                result = grouped.apply(func)
            else:
                result = grouped.agg(func)
            columns.append(result.rename(name))
        return pd.concat(columns, axis=1)

    # Aggregate over all trades via a constant grouping key.
    stats_all = _agg_stats(round_trips
                           .assign(ones=1)
                           .groupby('ones')[col])
    stats_all = stats_all.T.rename(columns={1.0: 'All trades'})

    # Aggregate separately for long (True) and short (False) trades.
    stats_long_short = _agg_stats(round_trips.groupby('long')[col])
    stats_long_short = stats_long_short.T.rename(
        columns={False: 'Short trades', True: 'Long trades'})

    # Join the two results
    return stats_all.join(stats_long_short)


def _groupby_consecutive(txn, max_delta=pd.Timedelta('8h')):
    """Merge transactions of the same direction separated by less than
    max_delta time duration.

    Parameters
    ----------
    txn : pd.DataFrame
        Prices and amounts of executed round_trips. One row per trade.
        Must have 'symbol', 'amount' and 'price' columns and a
        datetime(-like) index.
        - See full explanation in tears.create_full_tear_sheet

    max_delta : pandas.Timedelta (optional)
        Merge transactions in the same direction separated by less
        than max_delta time duration.


    Returns
    -------
    transactions : pd.DataFrame
        One row per merged block, indexed by the block's first datetime,
        with summed amount and volume-weighted average price.
    """

    def vwap(transaction):
        # Volume-weighted average price of a block of transactions.
        if transaction.amount.sum() == 0:
            warnings.warn('Zero transacted shares, setting vwap to nan.')
            return np.nan
        return (transaction.amount * transaction.price).sum() / \
            transaction.amount.sum()

    out = []
    for sym, t in txn.groupby('symbol'):
        t = t.sort_index()
        t.index.name = 'dt'
        t.index = pd.to_datetime(t.index)
        t = t.reset_index()

        # True for buys, False for sells.
        t['order_sign'] = t.amount > 0
        # New block id every time the trade direction flips...
        t['block_dir'] = (t.order_sign.shift(
            1) != t.order_sign).astype(int).cumsum()
        # ...and every time the gap to the previous trade exceeds max_delta.
        t['block_time'] = ((t.dt - t.dt.shift(1)) > max_delta).astype(int).cumsum()
        # Select only the columns vwap reads so the grouping columns are
        # not passed into the applied function.
        grouped_price = t.groupby(['block_dir', 'block_time'])[['amount', 'price']].apply(vwap)
        grouped_price.name = 'price'
        # Per block: total shares, symbol, and first transaction time.
        grouped_rest = t.groupby(['block_dir', 'block_time']).agg({
            'amount': 'sum',
            'symbol': 'first',
            'dt': 'first'})

        grouped = grouped_rest.join(grouped_price)

        out.append(grouped)

    out = pd.concat(out)
    out = out.set_index('dt')
    return out


def extract_round_trips(transactions,
                        portfolio_value=None):
    """Group transactions into "round trips". First, transactions are
    grouped by day and directionality. Then, long and short
    transactions are matched to create round-trip round_trips for which
    PnL, duration and returns are computed. Crossings where a position
    changes from long to short and vice versa are handled correctly.

    Under the hood, we reconstruct the individual shares in a
    portfolio over time and match round_trips in a FIFO order.

    For example, the following transactions would constitute one round trip:
    index                  amount   price    symbol
    2004-01-09 12:18:01    10       50      'AAPL'
    2004-01-09 15:12:53    10       100      'AAPL'
    2004-01-13 14:41:23    -10      100      'AAPL'
    2004-01-13 15:23:34    -10      200       'AAPL'

    First, the first two and last two round_trips will be merged into two
    single transactions (computing the price via vwap). Then, during
    the portfolio reconstruction, the two resulting transactions will
    be merged and result in 1 round-trip trade with a PnL of
    (150 * 20) - (75 * 20) = 1500.

    Note that round trips do not have to close out positions
    completely. For example, we could have removed the last
    transaction in the example above and still generated a round-trip
    over 10 shares with 10 shares left in the portfolio to be matched
    with a later transaction.

    Parameters
    ----------
    transactions : pd.DataFrame
        Prices and amounts of executed round_trips. One row per trade.
        - See full explanation in tears.create_full_tear_sheet

    portfolio_value : pd.Series (optional)
        Portfolio value (all net assets including cash) over time.
        Note that portfolio_value needs to beginning of day, so either
        use .shift() or positions.sum(axis='columns') / (1+returns).

    Returns
    -------
    round_trips : pd.DataFrame:
        DataFrame with one row per round trip.  The `returns` column
        contains returns in respect to the portfolio value while
        rt_returns are the returns in regard to the invested capital
        into that particular round-trip.
    """

    transactions = _groupby_consecutive(transactions)
    roundtrips = []

    for sym, trans_sym in transactions.groupby('symbol'):
        trans_sym = trans_sym.sort_index()
        # FIFO queues of per-share open prices and their open datetimes.
        price_stack = deque()
        dt_stack = deque()
        # Sign encodes direction: positive for buys, negative for sells.
        trans_sym['signed_price'] = trans_sym.price * np.sign(trans_sym.amount)
        trans_sym['abs_amount'] = trans_sym.amount.abs().astype(int)
        for dt, t in trans_sym.iterrows():
            if t.price < 0:
                warnings.warn('Negative price detected, ignoring for'
                              'round-trip.')
                continue

            # Expand the transaction into one signed price per share.
            indiv_prices = [t.signed_price] * t.abs_amount
            if (len(price_stack) == 0) or \
                    (copysign(1, price_stack[-1]) == copysign(1, t.amount)):
                # Same direction as the open position (or flat): extend it.
                price_stack.extend(indiv_prices)
                dt_stack.extend([dt] * len(indiv_prices))
            else:
                # Close round-trip
                pnl = 0
                invested = 0
                cur_open_dts = []

                for price in indiv_prices:
                    if len(price_stack) != 0 and \
                            (copysign(1, price_stack[-1]) != copysign(1, price)):
                        # Retrieve the first dt, stock-price pair from
                        # stack
                        prev_price = price_stack.popleft()
                        prev_dt = dt_stack.popleft()

                        # Open and close prices carry opposite signs, so
                        # -(price + prev_price) is close - open for longs
                        # and open - close for shorts.
                        pnl += -(price + prev_price)
                        cur_open_dts.append(prev_dt)
                        invested += abs(prev_price)

                    else:
                        # Crossing: the stack drained, so leftover shares
                        # open a position in the opposite direction.
                        # Push additional stock prices onto the stack
                        price_stack.append(price)
                        dt_stack.append(dt)

                # `price` is the last signed closing price; it is negative
                # when the closing transaction was a sell, i.e. the round
                # trip was long.
                roundtrips.append({'pnl': pnl,
                                   'open_dt': cur_open_dts[0],
                                   'close_dt': dt,
                                   'long': price < 0,
                                   'rt_returns': pnl / invested,
                                   'symbol': sym,
                                   })

    roundtrips = pd.DataFrame(roundtrips)

    roundtrips['duration'] = roundtrips['close_dt'].sub(roundtrips['open_dt'])

    if portfolio_value is not None:
        # Need to normalize so that we can join
        pv = pd.DataFrame(portfolio_value,
                          columns=['portfolio_value']) \
            .assign(date=portfolio_value.index)

        # Truncate close_dt to midnight so it can join against the daily
        # portfolio_value index.
        roundtrips['date'] = roundtrips.close_dt.apply(lambda x:
                                                       x.replace(hour=0,
                                                                 minute=0,
                                                                 second=0))
        # NOTE(review): hardcodes UTC; assumes a tz-aware
        # portfolio_value.index is in UTC — confirm for other timezones.
        if pv.index.tz is not None:  # portfolio_value.index has a timezone (e.g., UTC)
            roundtrips['date'] = roundtrips['date'].dt.tz_localize('UTC')

        tmp = roundtrips.join(pv, on='date', lsuffix='_')

        # Round-trip return relative to total portfolio value at close.
        roundtrips['returns'] = tmp.pnl / tmp.portfolio_value
        roundtrips = roundtrips.drop('date', axis='columns')

    return roundtrips


def add_closing_transactions(positions, transactions):
    """
    Appends transactions that close out all positions at the end of
    the timespan covered by positions data. Utilizes pricing information
    in the positions DataFrame to determine closing price.

    Parameters
    ----------
    positions : pd.DataFrame
        The positions that the strategy takes over time. Indexed by
        datetime, one column per symbol plus a 'cash' column.
    transactions : pd.DataFrame
        Prices and amounts of executed round_trips. One row per trade,
        indexed by transaction datetime.
        - See full explanation in tears.create_full_tear_sheet

    Returns
    -------
    closed_txns : pd.DataFrame
        Transactions with closing transactions appended, still indexed
        by transaction datetime.
    """

    closed_txns = transactions[['symbol', 'amount', 'price']]

    # Final positions (ex-cash); keep only symbols still held.
    pos_at_end = positions.drop('cash', axis=1).iloc[-1]
    open_pos = pos_at_end.replace(0, np.nan).dropna()
    # Add closing round_trips one second after the close to be sure
    # they don't conflict with other round_trips executed at that time.
    end_dt = open_pos.name + pd.Timedelta(seconds=1)

    for sym, ending_val in open_pos.items():
        txn_sym = transactions[transactions.symbol == sym]

        ending_amount = txn_sym.amount.sum()

        # Infer the per-share closing price from the ending position value.
        ending_price = ending_val / ending_amount
        closing_txn = {'symbol': sym,
                       'amount': -ending_amount,
                       'price': ending_price}

        closing_txn = pd.DataFrame(closing_txn, index=[end_dt])
        # Preserve the datetime index: downstream round-trip extraction
        # reads transaction times from the index, so ignore_index=True
        # would silently discard them (and end_dt along with them).
        closed_txns = pd.concat([closed_txns, closing_txn])

    closed_txns = closed_txns[closed_txns.amount != 0]

    return closed_txns


def apply_sector_mappings_to_round_trips(round_trips, sector_mappings):
    """
    Translates round trip symbols to sectors.

    Parameters
    ----------
    round_trips : pd.DataFrame:
        DataFrame with one row per-round-trip trade.
        - See full explanation in round_trips.extract_round_trips
    sector_mappings : dict or pd.Series, optional
        Security identifier to sector mapping.
        Security ids as keys, sectors as values.

    Returns
    -------
    sector_round_trips : pd.DataFrame
        Round trips with symbol names replaced by sector names.
    """

    # Work on a copy so the caller's DataFrame is left untouched.
    mapped = round_trips.copy()
    mapped.symbol = mapped.symbol.apply(
        lambda sym: sector_mappings.get(sym, 'No Sector Mapping'))

    # Drop any rows that contain missing values after the mapping.
    return mapped.dropna(axis=0)


# def gen_round_trip_stats(round_trips):
#     """Generate various round-trip statistics.
#
#     Parameters
#     ----------
#     round_trips : pd.DataFrame:
#         DataFrame with one row per-round-trip trade.
#         - See full explanation in round_trips.extract_round_trips
#
#     Returns
#     -------
#     stats : dict
#        A dictionary where each value is a pandas DataFrame containing
#        various round-trip statistics.
#
#     See also
#     --------
#     round_trips.print_round_trip_stats
#     """
#
#     stats = {'pnl': agg_all_long_short(round_trips, 'pnl', PNL_STATS), 'summary': agg_all_long_short(round_trips, 'pnl',
#                                                                                                      SUMMARY_STATS),
#              'duration': agg_all_long_short(round_trips, 'duration',
#                                             DURATION_STATS), 'returns': agg_all_long_short(round_trips, 'returns',
#                                                                                            RETURN_STATS),
#              'symbols': round_trips.groupby('symbol')['returns'].agg(RETURN_STATS).T}
#
#     return stats
def gen_round_trip_stats(round_trips):
    """Generate various round-trip statistics.

    Parameters
    ----------
    round_trips : pd.DataFrame:
        DataFrame with one row per-round-trip trade.
        - See full explanation in round_trips.extract_round_trips

    Returns
    -------
    stats : dict
       A dictionary where each value is a pandas DataFrame containing
       various round-trip statistics.

    See also
    --------
    round_trips.print_round_trip_stats
    """

    def _agg_grouped(grouped, stats_dict):
        # Evaluate each statistic individually so every result keeps the
        # display name from stats_dict and the OrderedDict ordering is
        # preserved. Applying built-in agg strings as a bare list would
        # label the results 'max'/'min'/... and reorder them.
        columns = []
        for name, func in stats_dict.items():
            if callable(func):
                result = grouped.apply(func)
            else:
                result = grouped.agg(func)
            columns.append(result.rename(name))
        return pd.concat(columns, axis=1)

    # Generate statistics for pnl, summary, duration, returns and
    # per-symbol returns.
    stats = {
        'pnl': agg_all_long_short(round_trips, 'pnl', PNL_STATS),
        'summary': agg_all_long_short(round_trips, 'pnl', SUMMARY_STATS),
        'duration': agg_all_long_short(round_trips, 'duration',
                                       DURATION_STATS),
        'returns': agg_all_long_short(round_trips, 'returns', RETURN_STATS),
        'symbols': _agg_grouped(round_trips.groupby('symbol')['returns'],
                                RETURN_STATS).T,
    }

    return stats


def print_round_trip_stats(round_trips, hide_pos=False, run_flask_app=False):
    """Print various round-trip statistics. Tries to pretty-print tables
    with HTML output if run inside IPython NB.

    Parameters
    ----------
    round_trips : pd.DataFrame:
        DataFrame with one row per-round-trip trade.
        - See full explanation in round_trips.extract_round_trips
    hide_pos : bool, optional, default: False
        Whether to hide the position-based statistics.

    run_flask_app : bool, optional, default: False
        Whether to run a Flask app to serve the round-trip statistics.

    See also
    --------
    round_trips.gen_round_trip_stats
    """

    stats = gen_round_trip_stats(round_trips)

    # (table, float format, title) for each statistics table, in the
    # order they should be printed. Return stats are shown in percent.
    tables = [
        (stats['summary'], '{:.2f}'.format, 'Summary stats'),
        (stats['pnl'], '${:.2f}'.format, 'PnL stats'),
        (stats['duration'], '{:.2f}'.format, 'Duration stats'),
        (stats['returns'] * 100, '{:.2f}%'.format, 'Return stats'),
    ]
    for table, fmt, title in tables:
        print_table(table,
                    float_format=fmt,
                    name=title,
                    run_flask_app=run_flask_app)

    if not hide_pos:
        stats['symbols'].columns = stats['symbols'].columns.map(format_asset)
        print_table(stats['symbols'] * 100,
                    float_format='{:.2f}%'.format,
                    name='Symbol stats',
                    run_flask_app=run_flask_app)
